import time
import warnings
from typing import Optional

import castle
import castle.algorithms
import networkx as nx
import numpy as np
import pandas as pd
from omegaconf import DictConfig
from utils.graph import get_dag_from_causal_graph, save_causal_graph

# Install a process-wide filter so that every FutureWarning raised after this
# module is imported is silently ignored.
warnings.filterwarnings(action="ignore", category=FutureWarning)

def learn_structure(
    cfg: DictConfig, data: pd.DataFrame, int_table: pd.DataFrame, save_path: Optional[str] = None, seed: int = 0
) -> nx.Graph:

    method = cfg.causal_discovery.lower()

    labels = list(data.columns)
    data = data.to_numpy()
    data = add_noise(data)

    t0 = time.time()
    match method:
        case "pc":
            method = castle.algorithms.PC(ci_test="chi2")
        case "pc_stable":
            method = castle.algorithms.PC(variant="stable", ci_test="chi2")
        case "ges":
            method = castle.algorithms.GES(criterion="bdeu")
        case "lingam":
            method = castle.algorithms.DirectLiNGAM()
        case "notears":
            method = castle.algorithms.Notears(max_iter = 100)
        case "dag_gnn":
            method = castle.algorithms.DAG_GNN(batch_size = 1024, device_type="gpu", seed= seed)
        case "golem":
            method = castle.algorithms.GOLEM(seed = seed, device_type="gpu")
        case "grandag":
            method = castle.algorithms.GraNDAG(input_dim = data.shape[1], batch_size=1024, hidden_num= 4, hidden_dim= 64, device_type="gpu")
        case "random":
            # Parameters computed via:
            # p = sum(in_degrees)/(num_nodes*(num_nodes - 1))
            # with in_degrees and num_nodes from the original ground truth graph

            if cfg.dataset == "dataset0":
                p = 0.04390420899854862
            elif cfg.dataset == "dataset1":
                p = 0.02060447544318512 
            else:
                p = 0.3

            causal_graph = nx.adjacency_matrix(nx.erdos_renyi_graph(data.shape[1], p, seed=seed, directed=True)).todense()
        case _:
            raise ValueError(f"Causal Discovery method {method} is not supported.")

    if method != "random":
        method.learn(data)
        causal_graph = method.causal_matrix

    t1 = time.time()
    delta_t = t1 - t0
    print(f"Causal Discovery method finished in {delta_t} seconds.")
          
    # Initialize a directed graph from the adjacency matrix
    G = nx.from_numpy_array(causal_graph, create_using=nx.DiGraph)

    # Optional: Relabel nodes to match original column names if necessary
    mapping = dict(zip(G.nodes(), labels))
    G = nx.relabel_nodes(G, mapping)

    # Create a copy of the graph. Done for compatibility reasons.
    new_graph = nx.DiGraph()
    new_graph.add_nodes_from(G.nodes())
    new_graph.add_edges_from(G.edges())

    save_causal_graph(G, save_path, "learned_causal_graph_raw")

    # Some algorithms may not return a DAG. If not, remove edges to make it a DAG.
    if not nx.is_directed_acyclic_graph(G):
        print("The learned graph is not a DAG.Removing edges to make it a DAG.")
        G = get_dag_from_causal_graph(G)
        # Convert pgmpy graph to networkx graph
        G = nx.DiGraph(G)

    return new_graph, delta_t


def add_noise(
    data: np.ndarray, scale: float = 0.1, rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """
    Add small Gaussian noise to every constant column of a 2-D array.

    A column is constant when it holds exactly one unique value. Only those
    columns are perturbed; all other columns are returned unchanged. The input
    is never modified.

    Args:
        data (np.ndarray): 2-D array-like of shape (n_samples, n_variables).
            A DataFrame is accepted too and is converted to an ndarray.
        scale (float): Standard deviation of the Gaussian noise. Defaults to 0.1.
        rng (np.random.Generator, optional): Generator used to draw the noise,
            for reproducibility. When None, NumPy's global random state is used.

    Returns:
        np.ndarray: A copy of ``data`` with noise added to its constant columns.
        When noise is added to an integer or boolean array, the result is
        promoted to float64. Otherwise NumPy would refuse to add float noise in place.
    """
    noisy_data = np.array(data, copy=True)
    constant_columns = [
        col for col in range(noisy_data.shape[1])
        if np.unique(noisy_data[:, col]).size == 1
    ]
    if not constant_columns:
        return noisy_data

    # In-place float addition into an int/bool array violates the 'same_kind'
    # casting rule, so promote before perturbing.
    if noisy_data.dtype.kind in "biu":
        noisy_data = noisy_data.astype(np.float64)

    draw = np.random.normal if rng is None else rng.normal
    for col in constant_columns:
        noisy_data[:, col] += draw(0, scale, noisy_data.shape[0])

    return noisy_data